Add python API for flash-attn #558
This PR provides a template for a flash-attn Python API. Based on a sycl-tla kernel, it currently reproduces part of flash-attn's functionality, such as `fwd` and `varlen_fwd`.

The current kernel implementation does not expose an external interface, and the test files are quite limited in scope. If possible, perhaps we can jointly maintain a common API, similar to what Dao-AILab/flash-attention does. Subsequent unit tests can be based on `test_flash_attn.py` to remain consistent with the official CUDA interface.

During the reproduction I ran into some edge-case issues when running tests, so I modified `xe_flash_attn_prefill_epilogue.hpp` and `xe_flash_attn_prefill.hpp`.

If you are interested, we can discuss this further. If there are any issues with the current code, please let me know. Thanks!
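For reference, a minimal sketch of the fixed-length `fwd` path through the upstream flash-attn-style Python interface (the import path follows Dao-AILab/flash-attention; the `xpu` device string is an assumption for the SYCL build, not something verified in this PR):

```python
import torch
from flash_attn import flash_attn_func  # same import path as upstream flash-attn

# Fixed-length fwd: q/k/v are (batch, seqlen, nheads, headdim) in fp16/bf16.
batch, seqlen, nheads, headdim = 2, 512, 8, 64
device = "xpu"  # assumption: torch XPU backend for the SYCL kernels
q = torch.randn(batch, seqlen, nheads, headdim, dtype=torch.float16, device=device)
k = torch.randn_like(q)
v = torch.randn_like(q)

# Output has the same shape as q; causal masking is optional.
out = flash_attn_func(q, k, v, dropout_p=0.0, causal=True)
```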
Current method to build the API:
After that, you can run tests using the same import statements as in `test_flash_attn.py`.
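For example, a `varlen_fwd` smoke check with the same imports used by `test_flash_attn.py` might look like the following (the shapes and device string here are illustrative, not taken from the PR):

```python
import torch
from flash_attn import flash_attn_varlen_func  # as imported in test_flash_attn.py

# varlen_fwd: q/k/v are packed as (total_tokens, nheads, headdim);
# cu_seqlens holds int32 cumulative sequence lengths, starting at 0.
nheads, headdim = 8, 64
seqlens = [128, 300, 512]
device = "xpu"  # assumption: torch XPU backend for the SYCL kernels

q = torch.randn(sum(seqlens), nheads, headdim, dtype=torch.float16, device=device)
k = torch.randn_like(q)
v = torch.randn_like(q)

cu_seqlens = torch.zeros(len(seqlens) + 1, dtype=torch.int32, device=device)
cu_seqlens[1:] = torch.tensor(seqlens, device=device).cumsum(0)

out = flash_attn_varlen_func(
    q, k, v,
    cu_seqlens_q=cu_seqlens, cu_seqlens_k=cu_seqlens,
    max_seqlen_q=max(seqlens), max_seqlen_k=max(seqlens),
    dropout_p=0.0, causal=True,
)
```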